import torch
import torch.nn as nn
import numpy as np
import copy
from torch.nn.parameter import Parameter
import math
import torch.nn.functional as F


class ConvUnit1d(nn.Module):
    """``nn.Conv1d`` followed by a non-linearity, stored as ``self.model``.

    Args:
        in_channels, out_channels, kernel, stride, padding: forwarded
            positionally to ``nn.Conv1d``.
        nonlinearity: activation module applied after the convolution. When
            omitted (``None``), a new ``nn.LeakyReLU(0.2)`` is created for this
            unit.
    """

    def __init__(self, in_channels, out_channels, kernel, stride=1, padding=0, nonlinearity=None):
        super(ConvUnit1d, self).__init__()
        if nonlinearity is None:
            # A module used as a default argument would be one shared instance
            # (hooks included) across every unit built with the default.
            nonlinearity = nn.LeakyReLU(0.2)
        self.model = nn.Sequential(
                     nn.Conv1d(in_channels, out_channels, kernel, stride, padding), nonlinearity)

    def forward(self, x):
        return self.model(x)


class ConvUnitTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` followed by a non-linearity, stored as ``self.model``.

    Args:
        in_channels, out_channels, kernel, stride, padding: forwarded to
            ``nn.ConvTranspose1d``.
        out_padding: forwarded as ``output_padding``; adds this many elements
            to one side of the output length.
        nonlinearity: activation module applied after the transposed
            convolution. When omitted (``None``), a new ``nn.LeakyReLU(0.2)`` is
            created for this unit.
    """

    def __init__(self, in_channels, out_channels, kernel, stride=1, padding=0, out_padding=0, nonlinearity=None):
        super(ConvUnitTranspose1d, self).__init__()
        if nonlinearity is None:
            # Fresh module per unit instead of one instance shared via the default.
            nonlinearity = nn.LeakyReLU(0.2)
        self.model = nn.Sequential(
                     nn.ConvTranspose1d(in_channels, out_channels, kernel, stride, padding,
                                        output_padding=out_padding),
                     nonlinearity)

    def forward(self, x):
        return self.model(x)


class LinearUnit(nn.Module):
    """``nn.Linear`` followed by a non-linearity, stored as ``self.model``.

    Args:
        in_features, out_features: forwarded to ``nn.Linear``.
        nonlinearity: activation module applied after the linear layer. When
            omitted (``None``), a new ``nn.LeakyReLU(0.2)`` is created for this
            unit.
    """

    def __init__(self, in_features, out_features, nonlinearity=None):
        super(LinearUnit, self).__init__()
        if nonlinearity is None:
            # Fresh module per unit instead of one instance shared via the default.
            nonlinearity = nn.LeakyReLU(0.2)
        self.model = nn.Sequential(
                     nn.Linear(in_features, out_features), nonlinearity)

    def forward(self, x):
        return self.model(x)


class EncX(nn.Module):
    """Convolutional encoder for individual observation vectors.

    Each length-``n`` vector goes through three strided ``ConvUnit1d`` layers
    (1 -> 8 -> 16 -> 32 channels; kernel = stride = 3, 2, 2; no padding). The
    result is flattened and projected by two ``LinearUnit`` layers to
    ``enc_dim`` features. The output is reshaped to ``(-1, T, enc_dim)``, so the
    number of input vectors must be a multiple of ``T``.

    Args:
        enc_dim: size of each encoded feature vector.
        n: length of each input vector.
        T: number of consecutive vectors grouped into one window.
    """

    def __init__(self, enc_dim, n=38, T=20):
        super(EncX, self).__init__()
        self.n = n
        self.T = T
        self.conv_dim = enc_dim

        self.k0, self.k1, self.k2 = 3, 2, 2
        self.s0, self.s1, self.s2 = 3, 2, 2
        self.p0, self.p1, self.p2 = 0, 0, 0
        # Kernel == stride and zero padding: the sequence length is divided
        # (with flooring) by the product of the three kernels.
        self.cd = [32, int(self.n / (self.k0 * self.k1 * self.k2))]

        layer_specs = (
            (1, 8, self.k0, self.s0, self.p0),
            (8, 16, self.k1, self.s1, self.p1),
            (16, 32, self.k2, self.s2, self.p2),
        )
        self.conv = nn.Sequential(*[
            ConvUnit1d(c_in, c_out, kernel=k, stride=s, padding=p)
            for c_in, c_out, k, s, p in layer_specs
        ])

        flat_dim = self.cd[0] * self.cd[1]
        self.conv_fc = nn.Sequential(
            LinearUnit(flat_dim, self.conv_dim * 2),
            LinearUnit(self.conv_dim * 2, self.conv_dim))

    def enc_x(self, x):
        """Encode ``x`` (any shape whose element count is a multiple of ``n``)."""
        flat_dim = self.cd[0] * self.cd[1]
        conv_out = self.conv(x.view(-1, 1, self.n))
        encoded = self.conv_fc(conv_out.view(-1, flat_dim))
        return encoded.view(-1, self.T, self.conv_dim)

    def forward(self, x):
        return self.enc_x(x)


class DecX(nn.Module):
    """Decode a vector into two length-``enc_dim`` outputs (mean and log-sigma branches).

    Each branch is two ``LinearUnit`` layers followed by three
    ``ConvUnitTranspose1d`` layers (32 -> 16 -> 8 -> 1 channels). The last layer
    of each branch uses ``nonlinear``; the same module instance is shared by
    both branches. Outputs are shaped ``(-1, 1, 1, enc_dim, 1)``.

    Args:
        enc_dim: both the input feature size and the output length. Only the
            values with a hard-coded geometry (36, 38, 25, 55) are accepted.
        T: stored as ``self.T``; not used by the methods in this class.
        nonlinear: activation module applied by the final transposed conv.

    Raises:
        ValueError: if ``enc_dim`` has no geometry entry.
    """

    # (enc_dim, kernels, strides, paddings, s_d). Kernels/strides/paddings are
    # listed as (layer 0, layer 1, layer 2); the decoder applies layer 2 first.
    # s_d is the length fed to the first transposed conv. For every entry the
    # transposed-conv output length works out to exactly enc_dim.
    _GEOMETRY = (
        (36, (3, 2, 2), (3, 2, 2), (0, 0, 0), 3),  # 36 // (3 * 2 * 2)
        (38, (2, 2, 2), (2, 2, 2), (1, 0, 0), 5),  # (38 + 2 * 1) // (2 * 2 * 2)
        (25, (3, 3, 2), (3, 3, 1), (1, 0, 0), 2),
        (55, (3, 4, 2), (3, 3, 2), (1, 0, 0), 3),
    )

    def __init__(self, enc_dim, T=20, nonlinear=nn.Tanh()):
        super(DecX, self).__init__()
        self.conv_dim = enc_dim
        self.dec_init_dim = enc_dim
        self.T = T

        match = next((row for row in self._GEOMETRY if self.conv_dim == row[0]), None)
        if match is None:
            raise ValueError('Invalid kpi numbers: please choose from the set [36,38,25,55]')
        _, (k0, k1, k2), (s0, s1, s2), (p0, p1, p2), s_d = match

        self.k0, self.k1, self.k2 = k0, k1, k2
        self.s0, self.s1, self.s2 = s0, s1, s2
        self.p0, self.p1, self.p2 = p0, p1, p2
        self.cd = [32, s_d]

        # Creation order matters for parameter initialisation (RNG draws) and
        # state_dict layout: fc_mu, mu, fc_logsigma, logsigma.
        self.deconv_fc_mu = self._fc_head()
        self.deconv_mu = self._deconv_stack(nonlinear)
        self.deconv_fc_logsigma = self._fc_head()
        self.deconv_logsigma = self._deconv_stack(nonlinear)

    def _fc_head(self):
        """Two linear units: dec_init_dim -> 2 * conv_dim -> cd[0] * cd[1]."""
        return nn.Sequential(
            LinearUnit(self.dec_init_dim, self.conv_dim * 2),
            LinearUnit(self.conv_dim * 2, self.cd[0] * self.cd[1]))

    def _deconv_stack(self, last_nonlinearity):
        """Three transposed-conv units, 32 -> 16 -> 8 -> 1 channels."""
        return nn.Sequential(
            ConvUnitTranspose1d(32, 16, kernel=self.k2, stride=self.s2, padding=self.p2),
            ConvUnitTranspose1d(16, 8, kernel=self.k1, stride=self.s1, padding=self.p1),
            ConvUnitTranspose1d(8, 1, kernel=self.k0, stride=self.s0, padding=self.p0,
                                nonlinearity=last_nonlinearity),
        )

    def _decode(self, fc, deconv, x):
        """Run one branch: linear head, reshape to (-1, 32, s_d), transposed convs."""
        seed = fc(x).view(-1, self.cd[0], self.cd[1])
        return deconv(seed).view(-1, 1, 1, self.conv_dim, 1)

    def dec_x_mu(self, x):
        return self._decode(self.deconv_fc_mu, self.deconv_mu, x)

    def dec_x_logsigma(self, x):
        return self._decode(self.deconv_fc_logsigma, self.deconv_logsigma, x)

    def forward(self, x):
        return self.dec_x_mu(x), self.dec_x_logsigma(x)


class LossFunctions:
    """Loss helpers; currently the diagonal-Gaussian log-density."""

    eps = 1e-8  # added to the variance before log / division when positive

    def log_normal(self, x, mu, var):
        """Return the sum over all elements of log N(x; mu, var).

        Args:
            x, mu, var: tensors broadcastable against each other; ``var`` is
                the variance (not the log-variance).

        Returns:
            A scalar tensor.
        """
        var = var + self.eps if self.eps > 0.0 else var
        squared_error = torch.pow(x - mu, 2)
        log_two_pi = np.log(2.0 * np.pi)
        return -0.5 * torch.sum(log_two_pi + torch.log(var) + squared_error / var)


class ReparameterizeTrick:
    """Draw diagonal-Gaussian samples with the reparameterization trick."""

    def reparameterize_gaussian(self, mean, logvar, random_sampling=True):
        """Return ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``, or ``mean``.

        Args:
            mean: Gaussian mean.
            logvar: log-variance; ``eps`` is drawn with its shape.
            random_sampling: truthy -> draw a sample; falsy -> return ``mean``
                unchanged (deterministic path).

        Returns:
            The sample (new tensor) or the ``mean`` object itself.
        """
        # Truthiness, not ``is True``: flags such as 1, numpy.bool_(True) or a
        # 0-d bool tensor must also enable sampling.
        if not random_sampling:
            return mean
        eps = torch.randn_like(logvar)
        std = torch.exp(0.5 * logvar)
        return mean + eps * std


class GenerationNet(nn.Module):
    """Generative side: prior over z_t given h_t and likelihood of x_t given (h_t, z_t).

    Args:
        h_dim: size of the per-step hidden feature h_t.
        z_dim: size of the latent z_t (also the DecX geometry key).
        x_dim: size of the per-step observation.
        win_len: number of time steps processed by ``Pz_prior`` / ``gen_px_hz``.
    """

    def __init__(self, h_dim, z_dim, x_dim, win_len):
        super(GenerationNet, self).__init__()
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.x_dim = x_dim
        self.win_len = win_len

        # NOTE(review): rt and Pz_h_mean_forward are not used by the methods
        # below; they are kept so existing checkpoints / callers still work.
        self.rt = ReparameterizeTrick()
        self.Pz_h_mean_forward = nn.Sequential(LinearUnit(h_dim, h_dim), LinearUnit(h_dim, z_dim))
        self.Pz_h_logvar_forward = nn.Sequential(LinearUnit(h_dim, h_dim), LinearUnit(h_dim, z_dim))
        self.Deconv_z = DecX(z_dim, win_len, nonlinear=nn.LeakyReLU(0.2))
        self.Px_logvar_forward = nn.Sequential(LinearUnit(h_dim, h_dim), LinearUnit(h_dim, x_dim))

    def Pz_prior(self, h, alpha1, gamma2):
        """Prior mean and log-variance of z for each of the first ``win_len`` steps.

        Args:
            h: indexed as ``h[:, t, :]``, each slice a 2-D (batch, h_dim) matrix.
            alpha1, gamma2: matrices whose product has shape (out_dim, h_dim).

        Returns:
            ``(z_mean, z_logvar)``, each the per-step results stacked along dim 1.
        """
        # The mixing matrix does not depend on t: compute it once.
        Whu_t = F.softmax(torch.mm(alpha1, gamma2), dim=0).t()
        means, logvars = [], []
        for t in range(self.win_len):
            h_t = h[:, t, :]
            means.append(F.leaky_relu(torch.mm(h_t, Whu_t), 0.2))
            logvars.append(self.Pz_h_logvar_forward(h_t))
        return torch.stack(means, dim=1), torch.stack(logvars, dim=1)

    def gen_px_hz(self, h, z, alpha0, gamma0, gamma1):
        """Per-step mean and log-sigma of x, each shaped (batch, win_len, 1, x_out, 1).

        Args:
            h: indexed as ``h[:, t, :]`` -> (batch, h_dim).
            z: indexed as ``z[:, t, :]`` -> (batch, z_dim); each slice is decoded
                by ``Deconv_z`` before mixing.
            alpha0, gamma0, gamma1: factors of the h->x and z->x mixing matrices.
        """
        # Both mixing matrices are independent of t: compute them once.
        Whx_t = F.softmax(torch.mm(alpha0, gamma0), dim=0).t()
        Wzx_t = F.softmax(torch.mm(alpha0, gamma1), dim=0).t()
        mus, logsigmas = [], []
        for t in range(self.win_len):
            h_t = h[:, t, :]
            # Only the mean branch of Deconv_z is needed, so call dec_x_mu
            # directly instead of forward() (which also runs the log-sigma
            # decoder). (-1, 1, 1, D, 1) -> (batch, D).
            z_t = self.Deconv_z.dec_x_mu(z[:, t, :]).flatten(start_dim=1)
            x_mu_t = torch.tanh(torch.mm(h_t, Whx_t) + torch.mm(z_t, Wzx_t))
            x_logsigma_t = torch.tanh(self.Px_logvar_forward(h_t))
            mus.append(x_mu_t[:, None, None, :, None])
            logsigmas.append(x_logsigma_t[:, None, None, :, None])
        return torch.cat(mus, dim=1), torch.cat(logsigmas, dim=1)

    def forward(self, h, z_posterior_forward, alpha1, alpha0, gamma0, gamma1, gamma2):
        z_mean_prior_forward, z_logvar_prior_forward = self.Pz_prior(h, alpha1, gamma2)
        x_mu, x_logsigma = self.gen_px_hz(h, z_posterior_forward, alpha0, gamma0, gamma1)

        return z_mean_prior_forward, z_logvar_prior_forward, x_mu, x_logsigma


class InferenceNet(nn.Module):
    """Inference side: posterior over z from x, plus the hidden sequence h.

    ``x`` is encoded by ``EncX`` into ``c_dim`` features, from which the
    posterior mean / log-variance of z are predicted and z is sampled (or set
    to the mean when ``is_train`` is falsy). z and x each pass through a
    graph-mode transformer. Their concatenation goes through a plain
    transformer and a projection to ``h_dim``.

    NOTE(review): ``x.squeeze(2).squeeze(-1)`` suggests x is
    (batch, win_len, 1, x_dim, 1) — confirm against callers.
    """

    def __init__(self, c_dim, z_dim, x_dim, h_dim, embd_h, layer_xz, layer_h, n_head, vocab_len, dropout,
                 q_len, win_len, device=torch.device('cpu'), is_train=True):
        super(InferenceNet, self).__init__()
        # Hyper-parameters kept as plain attributes.
        self.c_dim = c_dim
        self.z_dim = z_dim
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.embd_h = embd_h
        self.layer_xz = layer_xz
        self.layer_h = layer_h
        self.n_head = n_head
        self.vocab_len = vocab_len
        self.dropout = dropout
        self.q_len = q_len
        self.win_len = win_len
        self.device = device
        self.is_train = is_train

        self.rt = ReparameterizeTrick()
        common = dict(n_head=n_head, vocab_len=vocab_len, dropout=dropout, q_len=q_len,
                      win_len=win_len, scale_att=False, device=device)
        # Submodule creation order is kept (TransZ, TransX, TransH, encX, heads)
        # so parameter initialisation draws from the RNG in the same order.
        self.TransZ = VariationalGraphTransformer(n_time_series=z_dim, num_layer=layer_xz, n_embd=z_dim,
                                                  use_gcn=True, **common)
        self.TransX = VariationalGraphTransformer(n_time_series=x_dim, num_layer=layer_xz, n_embd=x_dim,
                                                  use_gcn=True, **common)
        # Non-graph transformer: its output width is (x_dim + z_dim) + embd_h.
        self.TransH = VariationalGraphTransformer(n_time_series=x_dim + z_dim, num_layer=layer_h, n_embd=embd_h,
                                                  use_gcn=False, **common)

        self.encX = EncX(c_dim, x_dim, win_len)
        self.Pz_xh_mean_forward = nn.Sequential(LinearUnit(c_dim, h_dim), LinearUnit(h_dim, z_dim))
        self.Pz_xh_logvar_forward = nn.Sequential(LinearUnit(c_dim, h_dim), LinearUnit(h_dim, z_dim))
        self.proj_h = nn.Sequential(LinearUnit(embd_h + x_dim + z_dim, h_dim), LinearUnit(h_dim, h_dim))

    def infer_qz(self, x, alpha1, alpha0):
        """Return ``(z, z_mean, z_logvar, h)`` for observations ``x``.

        ``alpha1`` / ``alpha0`` are forwarded as the ``y`` argument of the z / x
        transformers.
        """
        x_features = self.encX(x.float())
        z_mean = self.Pz_xh_mean_forward(x_features)
        z_logvar = self.Pz_xh_logvar_forward(x_features)
        z = self.rt.reparameterize_gaussian(z_mean, z_logvar, self.is_train)

        z_graph = self.TransZ(z, alpha1)
        x_graph = self.TransX(x.squeeze(2).squeeze(-1).float(), alpha0)
        h = self.proj_h(self.TransH(torch.cat((x_graph, z_graph), dim=-1)))
        return z, z_mean, z_logvar, h

    def forward(self, x, alpha1, alpha0):
        z, z_mean, z_logvar, h = self.infer_qz(x, alpha1, alpha0)
        return z, z_mean, z_logvar, h


class Attention(nn.Module):
    """Causal multi-head self-attention with a convolutional query/key projection.

    Query and key come from an ``nn.Conv1d`` of width ``q_len`` over the
    left-padded time axis, so step t only sees inputs up to t. Attention
    scores are masked with a lower-triangular ``win_len`` x ``win_len`` mask.

    With ``use_gcn`` the value is the input tiled once per head (no learned
    projection) and heads are averaged by ``Pooling``. Otherwise the value is a
    learned projection and heads are concatenated then projected back to
    ``n_embd``.
    """

    def __init__(self, n_head, n_embd, win_len, scale, q_len, attn_pdrop=0.1, resid_pdrop=0.1, use_gcn=False):
        super(Attention, self).__init__()
        causal = torch.tril(torch.ones(win_len, win_len)).view(1, 1, win_len, win_len)
        self.register_buffer('mask_tri', causal)
        self.use_gcn = use_gcn
        self.n_head = n_head
        self.split_size = n_embd * n_head
        self.scale = scale
        self.q_len = q_len
        self.query_key = nn.Conv1d(n_embd, n_embd * n_head * 2, q_len)
        value_rf = 0 if use_gcn else 1
        self.value = Conv1D(n_embd * n_head, value_rf, n_embd)
        self.c_proj = Pooling(n_head) if use_gcn else Conv1D(n_embd, 1, n_embd * n_head)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def attn(self, query: torch.Tensor, key, value: torch.Tensor):
        """Masked softmax attention; ``key`` is already transposed to (..., d, T)."""
        scores = torch.matmul(query, key)
        if self.scale:
            scores = scores / math.sqrt(value.size(-1))
        n_rows, n_cols = scores.size(-2), scores.size(-1)
        keep = self.mask_tri[:, :, :n_rows, :n_cols]
        # Masked positions get -1e9 so softmax gives them ~zero weight.
        scores = scores * keep + -1e9 * (1 - keep)
        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, value)

    def merge_heads(self, x):
        """(B, H, T, d) -> (B, T, H * d)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        *leading, heads, head_dim = x.size()
        return x.view(*leading, heads * head_dim)

    def split_heads(self, x, k=False):
        """(B, T, H * d) -> (B, H, T, d), or (B, H, d, T) when ``k`` is set."""
        head_dim = x.size(-1) // self.n_head
        x = x.view(*x.size()[:-1], self.n_head, head_dim)
        order = (0, 2, 3, 1) if k else (0, 2, 1, 3)
        return x.permute(*order)

    def forward(self, x):
        value = self.value(x)
        # Left-pad time by q_len - 1 so the width-q_len conv is causal.
        padded = F.pad(x.permute(0, 2, 1), pad=(self.q_len - 1, 0))
        query, key = self.query_key(padded).permute(0, 2, 1).split(self.split_size, dim=2)
        out = self.attn(self.split_heads(query), self.split_heads(key, k=True), self.split_heads(value))
        if not self.use_gcn:
            out = self.merge_heads(out)
        return self.resid_dropout(self.c_proj(out))


class Pooling(nn.Module):
    """Average over the head axis: (B, H, T, P) -> (B, T, P) when H == n_head."""

    def __init__(self, n_head):
        super(Pooling, self).__init__()
        self.pool = nn.AvgPool1d(kernel_size=n_head)

    def forward(self, x):
        batch, heads, steps, feats = x.size()
        # Put heads last so AvgPool1d averages across them.
        per_position = x.view(batch, heads, -1).transpose(1, 2)
        return self.pool(per_position).view(batch, steps, feats)


class Conv1D(nn.Module):
    """Position-wise linear map (``rf == 1``) or feature tiling (any other ``rf``).

    Args:
        out_dim: size of the output's last dimension.
        rf: ``1`` -> learned affine map ``x @ w + b`` with ``w`` of shape
            (in_dim, out_dim) initialised from N(0, 0.02) and ``b`` zeros.
            Any other value -> no parameters; the last dimension is repeated
            ``out_dim // in_dim`` times (``out_dim`` must be a multiple of
            ``in_dim``).
        in_dim: size of the input's last dimension.
    """

    def __init__(self, out_dim, rf, in_dim):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.out_dim = out_dim
        self.in_dim = in_dim
        if rf == 1:
            w = torch.empty(in_dim, out_dim)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(out_dim))

    def forward(self, x):
        """Apply the map/tiling to the last dimension of ``x`` (any rank >= 1)."""
        size_out = x.size()[:-1] + (self.out_dim,)
        if self.rf == 1:
            # reshape (not view) so non-contiguous inputs are accepted.
            out = torch.addmm(self.b, x.reshape(-1, x.size(-1)), self.w)
            return out.view(*size_out)
        reps = self.out_dim // self.in_dim
        # expand works for any rank (the former repeat(1, 1, reps, 1) only fit 2-D/3-D).
        tiled = x.unsqueeze(-2).expand(*x.shape[:-1], reps, x.size(-1))
        return tiled.reshape(*size_out)


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learned gain ``g`` and bias ``b``.

    Uses the biased (population) variance; ``e`` is added inside the square root.
    """

    def __init__(self, n_embd, e=1e-5):
        super(LayerNorm, self).__init__()
        self.g = nn.Parameter(torch.ones(n_embd))
        self.b = nn.Parameter(torch.zeros(n_embd))
        self.e = e

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.e)
        return self.g * normalized + self.b


class MLP(nn.Module):
    """Position-wise feed-forward: n_embd -> n_state (ReLU) -> n_embd, then dropout(0.1)."""

    def __init__(self, n_state, n_embd):
        super(MLP, self).__init__()
        self.c_fc = Conv1D(n_state, 1, n_embd)
        self.c_proj = Conv1D(n_embd, 1, n_state)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))


class Block(nn.Module):
    """Post-norm transformer block: attention + residual + LayerNorm, then a mixing step.

    Without ``use_gcn`` the mixing step is the MLP. With ``use_gcn`` it
    multiplies the features by a node-to-node weight matrix built from ``y``
    (softmax over relu(y @ y.T)), followed by ``linear_map``. In that case the
    feature dimension must equal ``y.size(0)``.
    """

    def __init__(self, n_head, win_len, n_embd, scale, q_len, use_gcn):
        super(Block, self).__init__()
        self.use_gcn = use_gcn
        self.attn = Attention(n_head, n_embd, win_len, scale, q_len, use_gcn=use_gcn)
        self.ln_1 = LayerNorm(n_embd)
        # Built even in graph mode (where it is unused) to keep the parameter layout.
        self.mlp = MLP(4 * n_embd, n_embd)
        self.ln_2 = LayerNorm(n_embd)
        if use_gcn:
            self.linear_map = nn.Linear(n_embd, n_embd)

    def forward(self, x, y):
        ln1 = self.ln_1(x + self.attn(x))
        if not self.use_gcn:
            return self.ln_2(ln1 + self.mlp(ln1))
        graph_weights = F.softmax(F.relu(torch.mm(y, y.t())), dim=0).t()
        return self.ln_2(ln1 + self.linear_map(torch.matmul(ln1, graph_weights)))


class SelfDefinedTransformer(nn.Module):
    """Stack of ``num_layer`` identical ``Block``s with learned position embeddings.

    With ``use_gcn`` the position embedding (width ``n_time_series``, which
    must equal ``n_embd``) is added to the input. Otherwise it (width
    ``n_embd``) is concatenated to the input, so blocks work on
    ``n_time_series + n_embd`` features.

    NOTE(review): ``vocab_len`` is passed to ``Block`` in the ``win_len``
    position, so it sizes both the position table and the attention mask;
    sequence length must not exceed ``vocab_len``. ``win_len`` and ``dropout``
    are stored but not used by this class.

    Args:
        device: kept for backward compatibility. Position indices are created
            on the embedding table's own device, so the module keeps working
            after ``.to(...)``.
    """

    def __init__(self, n_time_series, n_head, num_layer, n_embd, vocab_len, win_len, dropout, scale_att,
                 q_len, use_gcn, device=torch.device('cpu')):
        super(SelfDefinedTransformer, self).__init__()
        self.input_dim = n_time_series
        self.n_head = n_head
        self.num_layer = num_layer
        self.n_embd = n_embd
        self.vocab_len = vocab_len
        self.win_len = win_len
        self.dropout = dropout
        self.scale_att = scale_att
        self.q_len = q_len
        self.use_gcn = use_gcn
        self.device = device
        if use_gcn:
            assert n_time_series == n_embd
            self.po_embed = nn.Embedding(vocab_len, n_time_series)
            block = Block(n_head, vocab_len, n_time_series, scale=scale_att, q_len=q_len, use_gcn=use_gcn)
        else:
            self.po_embed = nn.Embedding(vocab_len, n_embd)
            block = Block(n_head, vocab_len, n_time_series + n_embd, scale=scale_att, q_len=q_len, use_gcn=use_gcn)
        self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(num_layer)])
        nn.init.normal_(self.po_embed.weight, std=0.02)

    def forward(self, x, y):
        """Run the block stack on ``x`` (batch, length, n_time_series); ``y`` goes to every block."""
        batch_size, length = x.size(0), x.size(1)
        # torch.arange directly: wrapping it in torch.tensor(...) warns on every
        # call. Follow the embedding's device rather than the stored one.
        position = torch.arange(length, dtype=torch.long, device=self.po_embed.weight.device)
        po_embedding = self.po_embed(position).unsqueeze(0).expand(batch_size, -1, -1)
        if self.use_gcn:
            x = x + po_embedding
        else:
            x = torch.cat((x, po_embedding), dim=2)
        for block in self.blocks:
            x = block(x, y)
        return x


class VariationalGraphTransformer(nn.Module):
    """Wrapper around ``SelfDefinedTransformer`` that re-initialises conv/linear weights."""

    def __init__(self, n_time_series: int, n_head: int, num_layer: int, n_embd: int, vocab_len: int, dropout: float,
                 q_len: int, win_len: int, scale_att: bool = False, use_gcn: bool = False, device=torch.device('cpu')):
        super(VariationalGraphTransformer, self).__init__()
        self.transformer = SelfDefinedTransformer(n_time_series, n_head, num_layer, n_embd, vocab_len, win_len,
                                                  dropout, scale_att, q_len, use_gcn, device)
        self._initialize_weights()

    def _initialize_weights(self):
        """Set every ``nn.Conv1d`` / ``nn.Linear`` weight to N(0, 0.01) and its bias (if any) to 0.

        Modules that are not ``nn.Conv1d`` / ``nn.Linear`` instances keep their
        own initialisation.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.normal_(m.weight, 0, 0.01)
                # Layers built with bias=False have bias None; skip them.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None):
        """Return the transformer output; ``y`` is forwarded to every block."""
        return self.transformer(x, y)
